{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一.基本原理\n",
    "非负矩阵分解(non-negative matrix factorization，NMF)原理很简单，与SVD将矩阵分解为三个矩阵类似，NMF将矩阵分解为两个小矩阵，比如原始矩阵$A_{m\\times n}$分解为$W_{m\\times k}$与$H_{k\\times n}$的乘积，即：   \n",
    "\n",
    "$$\n",
    "A_{m\\times n}\\simeq W_{m\\times k}H_{k\\times n}\n",
    "$$  \n",
    "\n",
    "这里，要求$A,W,H$中的元素都非负，而参数估计也很简单，最小化如下的平方损失即可：   \n",
    "\n",
    "$$\n",
    "L(A,W,H)=\\frac{1}{2}\\sum_{i=1}^m\\sum_{j=1}^n(A_{ij}-(WH)_{ij})^2=\\frac{1}{2}\\sum_{i=1}^m\\sum_{j=1}^n(A_{ij}-\\sum_{l=1}^kW_{il}H_{lj})^2\n",
    "$$\n",
    "\n",
    "所以：   \n",
    "\n",
    "$$\n",
    "W^*,H^*=arg\\min_{W,H}L(A,W,H)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二.参数估计\n",
    "对于参数估计，采用梯度下降即可，下面推导一下：   \n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial W_{ik}}=\\sum_j(A_{ij}-(WH)_{ij})\\cdot \\frac{\\partial (WH)_{ij}}{\\partial W_{ik}}=\\sum_j(A_{ij}-(WH)_{ij})\\cdot (-H_{kj})=-\\sum_j A_{ij}H_{kj}+\\sum_j(WH)_{ij}H_{kj}=(WHH^T)_{ik}-(AH^T)_{ik}\n",
    "$$  \n",
    "\n",
    "类似地：   \n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial H_{kj}}=(W^TWH)_{kj}-(W^TA)_{kj}\n",
    "$$  \n",
    "\n",
    "所以，梯度下降的更新公式可以表示如下：   \n",
    "\n",
    "$$\n",
    "W_{ik}\\leftarrow W_{ik}+\\alpha_1[(AH^T)_{ik}-(WHH^T)_{ik}]\\\\\n",
    "H_{kj}\\leftarrow H_{kj}+\\alpha_2[(W^TA)_{kj}-(W^TWH)_{kj}]\n",
    "$$  \n",
    "\n",
    "这里，$\\alpha_1>0,\\alpha_2>0$为学习率，如果我们巧妙的设置：  \n",
    "\n",
    "$$\n",
    "\\alpha_1=\\frac{W_{ik}}{(WHH^T)_{ik}}\\\\\n",
    "\\alpha_2=\\frac{H_{kj}}{(W^TWH)_{kj}}\n",
    "$$  \n",
    "\n",
    "那么，迭代公式为：   \n",
    "\n",
    "$$\n",
    "W_{ik}\\leftarrow W_{ik}\\cdot\\frac{(AH^T)_{ik}}{(WHH^T)_{ik}}\\\\\n",
    "H_{kj}\\leftarrow H_{kj}\\cdot\\frac{(W^TA)_{kj}}{(W^TWH)_{kj}}\n",
    "$$  \n",
    "\n",
    "可以发现该迭代公式也很好的满足了我们的约束条件，$W,H$在迭代过程中始终非负"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三.代码实现\n",
    "NMF在NLP中应用也比较多，比如做主题模型...，下面用LDA那一节的样本做测试   \n",
    "\n",
    "#### 准备数据...."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs=[\n",
    "    [\"有\",\"微信\",\"红包\",\"的\",\"软件\"],\n",
    "    [\"微信\",\"支付\",\"不行\",\"的\"],\n",
    "    [\"我们\",\"需要\",\"稳定的\",\"微信\",\"支付\",\"接口\"],\n",
    "    [\"申请\",\"公众号\",\"认证\"],\n",
    "    [\"这个\",\"还有\",\"几天\",\"放\",\"垃圾\",\"流量\"],\n",
    "    [\"可以\",\"提供\",\"聚合\",\"支付\",\"系统\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2id={}\n",
    "idx=0\n",
    "W=[]\n",
    "for doc in docs:\n",
    "    tmp=[]\n",
    "    for word in doc:\n",
    "        if word in word2id:\n",
    "            tmp.append(word2id[word])\n",
    "        else:\n",
    "            word2id[word]=idx\n",
    "            idx+=1\n",
    "            tmp.append(word2id[word])\n",
    "    W.append(tmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, 2, 3, 4],\n",
       " [1, 5, 6, 3],\n",
       " [7, 8, 9, 1, 5, 10],\n",
       " [11, 12, 13],\n",
       " [14, 15, 16, 17, 18, 19],\n",
       " [20, 21, 22, 5, 23]]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "data = np.zeros(shape=(len(docs), len(word2id)))\n",
    "for idx, w in enumerate(W):\n",
    "    for i in w:\n",
    "        data[idx][i] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 训练模型...."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "from ml_models.decomposition import NMF\n",
    "\n",
    "nmf = NMF(n_components=3,epochs=10)\n",
    "trans = nmf.fit_transform(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 查看效果..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosine(x1, x2):\n",
    "    return x1.dot(x2) / (np.sqrt(np.sum(np.power(x1, 2))) * np.sqrt(np.sum(np.power(x2, 2))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1844189988582961"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第二句和第三句应该比较近似，因为它们都含有“微信”，“支付”\n",
    "cosine(trans[1],trans[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.16058138087290988"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#而第二句和第四句的相似度显然不如第二句和第三句，因为它们没有完全相同的词\n",
    "cosine(trans[1],trans[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2459676681757762"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#而第一句和第二句都含有“微信”,\"的\"这两个词，所以相似度会比第三,四句高\n",
    "cosine(trans[1],trans[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
